SAM 2 Annotation Tool¶

In this notebook, I walk through a user-friendly tool I created that allows you to accurately label a video for object tracking tasks.¶

The tool annotates the video by passing it through Meta's SAM 2 model and allowing a human-in-the-loop to correct its mistakes. SAM 2 is designed for exactly this use case: it is a promptable visual segmentation (PVS) model, so before any object can be tracked, it must be identified in a given frame with one or more points, a bounding box, or a mask. After the initial prompt, SAM 2 tracks the object(s) throughout the video. If a given masklet is lost (e.g., due to an occlusion), SAM 2 requires a new prompt to regain it.¶
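For reference, a point prompt is just a pair of parallel arrays of (x, y) pixel coordinates and 1/0 labels (1 = on the object, 0 = background), while a box prompt is a single (x_min, y_min, x_max, y_max) array. A minimal sketch of these formats (the coordinates below are made up):

```python
import numpy as np

# Point prompt: two clicks on the object (label 1) and one click off it (label 0).
# These pixel positions are hypothetical.
points = np.array([[210, 350], [250, 220], [275, 175]], dtype=np.float32)
labels = np.array([1, 1, 0], dtype=np.int32)

# Box prompt: (x_min, y_min, x_max, y_max) in pixel coordinates.
box = np.array([190, 150, 300, 400], dtype=np.float32)

assert points.shape == (3, 2) and labels.shape == (3,)
assert (box[2:] > box[:2]).all()  # a valid box has x_max > x_min and y_max > y_min
```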

SAM 2's transformer-based architecture learns both motion- and appearance-based features, outperforming many of the top existing tracker models. Its promptable nature also makes it especially well-suited for providing initial high-fidelity labels that can be further refined with just a few clicks.¶

Steps
1) Load in SAM 2 and the video
2) Set the initial prompt
3) Run inference with SAM 2
4) Find the frame(s) with an object whose masklet was lost
5) Re-label said frame(s) to regain the masklet
6) Re-run inference with the correction
7) Output the final labeled data
In [6]:
import os
import math
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.patches import Rectangle
import ipywidgets as widgets
from IPython.display import display  # note: IPython's Image would shadow PIL.Image if imported here

Set up the environment¶

In [3]:
%%capture
!git clone https://github.com/facebookresearch/segment-anything-2.git
os.chdir('segment-anything-2')
!pip install -e .
!./checkpoints/download_ckpts.sh
!python setup.py clean
!python setup.py build_ext --inplace
In [77]:
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

Load In SAM 2 and the Video¶

First load in an instance of the SAM 2 predictor, choosing from its "tiny", "small", or "large" versions.¶

In [78]:
from sam2.build_sam import build_sam2_video_predictor

model_size = "tiny" # Set to: 'tiny', 'small', or 'large'

sam2_checkpoint = f"sam2_hiera_{model_size}.pt"
model_cfg = f"sam2_hiera_{model_size[0]}.yaml"

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)

The functions below allow us to visualize SAM 2's prompts and output. More specifically, they annotate the video frames with the user-selected points, the resulting segmented mask, and the implied bounding box.¶

In [79]:
# Function to draw a mask (and its implied bounding box)
def show_mask(mask, ax, obj_id=None, random_color=False):
    # Generate a color for the mask
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    # Plot the mask (check for None/empty before reading its shape)
    if mask is not None and np.any(mask):
        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        # Plot the bounding box
        x_min, y_min, x_max, y_max = mask_to_bb(mask)
        rect = Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(rect)
        ax.imshow(mask_image)

# Function to display the selected points for labeling
def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

# Function to derive a bounding box from a mask
def mask_to_bb(mask):
    if mask is not None and np.any(mask):
        rows, cols = np.where(mask.squeeze())
        y_min, y_max = rows.min(), rows.max()
        x_min, x_max = cols.min(), cols.max()
        xyxy = (x_min, y_min, x_max, y_max)
        return xyxy
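As a sanity check, the `np.where`-based box derivation used in `mask_to_bb` can be verified on a tiny synthetic mask:

```python
import numpy as np

# A 5x5 mask with a block of True pixels spanning rows 1-3 and cols 2-3
toy_mask = np.zeros((5, 5), dtype=bool)
toy_mask[1:4, 2:4] = True

# Same logic as mask_to_bb above
rows, cols = np.where(toy_mask.squeeze())
xyxy = (cols.min(), rows.min(), cols.max(), rows.max())

assert tuple(int(v) for v in xyxy) == (2, 1, 3, 3)  # (x_min, y_min, x_max, y_max)
```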

Store the video as a list of JPEG frames with filenames like <frame_index>.jpg.¶

In [80]:
def load_video(input_video, output_folder, reshape_scale=1.0):
    # Ensure the output folder exists
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    
    # Open the video file
    cap = cv2.VideoCapture(input_video)
    
    # Check if the video opened successfully
    if not cap.isOpened():
        raise ValueError("Error opening video file")
    
    # Get video properties
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    
    # Loop through all frames
    frame_number = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
    
        # Size down
        height, width = frame.shape[:2]
        new_dim = (int(width * reshape_scale), int(height * reshape_scale))
        resized_frame = cv2.resize(frame, new_dim, interpolation=cv2.INTER_AREA)
        
        # Construct filename with leading zeros
        filename = os.path.join(output_folder, f'{frame_number:05d}.jpg')
        
        # Save the frame as an image
        cv2.imwrite(filename, resized_frame)
        
        frame_number += 1
    
    # Release the video capture object
    cap.release()
    
    print(f"Saved {frame_number} frames ({fps:.2f} fps, {frame_count} frames reported by OpenCV).")
In [82]:
# Define input and output paths
input_video = 'marshawn.mp4'
# `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = 'marshawn_frames'

# Load in the video for SAM2 processing
load_video(input_video, video_dir, 1)
In [83]:
# Scan all the JPEG frame names in this directory
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

Initialize the inference state with the video frames, from which SAM 2 will run segmentation inference.¶

In [ ]:
inference_state = predictor.init_state(video_path=video_dir)

Set the Initial Prompts¶

The function below creates a user-friendly interactive window for selecting the points on the object(s) that you want SAM 2 to track.¶

In [118]:
def annotate(frame_idx):
    # Initialize data structure
    global data
    data = []

    # Load the image
    image_path = os.path.join(video_dir, frame_names[frame_idx])
    img = Image.open(image_path)

    # Set up the plot
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.set_title(f"Select Points and Assign Labels: Frame {frame_idx}")
    im = ax.imshow(img)
    ax.set_xlim(0, img.width)
    ax.set_ylim(img.height, 0)  # Invert y-axis to match image coordinates
    
    if 'video_segments' in globals() and frame_idx in video_segments:
        for out_obj_id, out_mask in video_segments[frame_idx].items():
            if out_mask is not None and np.any(out_mask):
                show_mask(out_mask, ax, obj_id=out_obj_id)

    # Initialize lists to store points and labels for the current object
    global current_points, current_labels, current_obj_id, current_label
    current_points = []
    current_labels = []
    current_obj_id = 1
    current_label = 1

    # Function to add points and plot them
    def on_click(event):
        if event.inaxes == ax:
            x, y = int(event.xdata), int(event.ydata)
            current_points.append([x, y])
            current_labels.append(current_label)
            ax.text(x, y, str(current_obj_id), color='blue' if current_label == 1 else 'red', fontsize=12, ha='center')
            fig.canvas.draw()
            print(f"Point added: ({x}, {y}) with label {current_label} for object {current_obj_id}")

            # Update data with each click
            existing_obj = next((item for item in data if item[0] == current_obj_id), None)
            if existing_obj:
                existing_obj[1].append([x, y])
                existing_obj[2].append(current_label)
            else:
                data.append((current_obj_id, [[x, y]], [current_label]))
            print(f"Data updated for object {current_obj_id}: {data}")

    # Connect the click event to the callback function
    cid = fig.canvas.mpl_connect('button_press_event', on_click)

    # Function to update labels
    def update_label(change):
        global current_label
        current_label = int(change['new'])
        print(f"Current label set to {current_label}")

    # Create dropdown for label selection using ipywidgets
    label_dropdown = widgets.Dropdown(
        options=[1, 0],
        value=current_label,
        description='Label:',
    )

    label_dropdown.observe(update_label, names='value')
    display(label_dropdown)

    # Function to update the object ID
    def update_obj_id(change):
        global current_obj_id
        save_current_object_data()
        current_obj_id = int(change['new'])
        print(f"Switched to object ID {current_obj_id}")

    # Function to save current object data
    def save_current_object_data():
        global current_points, current_labels, data, current_obj_id
        if current_points and current_labels:
            existing_obj = next((item for item in data if item[0] == current_obj_id), None)
            if existing_obj:
                existing_obj[1].extend(current_points)
                existing_obj[2].extend(current_labels)
            else:
                data.append((current_obj_id, current_points.copy(), current_labels.copy()))
            print(f"Data saved for object {current_obj_id}: {data}")
            current_points.clear()
            current_labels.clear()

    # Create dropdown for object ID selection using ipywidgets
    object_id_dropdown = widgets.Dropdown(
        options=[i for i in range(1, 51)],
        value=current_obj_id,
        description='Object ID:',
    )

    object_id_dropdown.observe(update_obj_id, names='value')
    display(object_id_dropdown)

    plt.show()
The two dropdowns allow the user to:
1) Select the label of the clicked point (1 = the point is on the object; 0 = the point is not on the object)
2) Select the Object ID to distinguish multiple objects/masks from one another throughout the video. Here, since only one object is segmented, the ID is always 1.
We use this window to positively label points on Marshawn Lynch and negatively label points on the defender next to him in order to get an accurate mask for SAM 2 to propagate.
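For concreteness, a session with three clicks on a single object produces a `data` list shaped like the following (the coordinates are hypothetical):

```python
# Each tuple is (object_id, [[x, y], ...], [label, ...]);
# the points and labels lists stay index-aligned.
data = [
    (1, [[412, 230], [455, 310], [520, 305]], [1, 1, 0]),
]

obj_id, points, labels = data[0]
assert obj_id == 1
assert len(points) == len(labels) == 3
```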
In [3]:
frame_idx = 0
annotate(frame_idx)

These functions transform the selected points into the prompt structure with which SAM 2's predictor object is updated.¶

In [123]:
# Function to structure the selected points into prompts for SAM 2
def make_prompts(data: list):
    """
    Inputs:
    - data (list of tuples): data on the objects to be tracked, with each tuple formatted as
                  (object_id, [[x1, y1], [x2, y2], ...], [label1, label2, ...])
    Outputs:
    - prompts: a dict with all the visual prompt information for SAM2
    """
    prompts = {}
    for obj_id, points, labels in data:
        prompts[obj_id] = (
            np.array(points, dtype=np.float32),
            np.array(labels, np.int32)
        )
    return prompts

# Function to add the prompts to the SAM 2 predictor
def add_prompts(prompts, frame_idx, inference_state, is_refinement=False):
    if is_refinement:
        predictor.reset_state(inference_state)
    # Iterate over each object in the prompts dictionary
    for obj_id, (points, labels) in prompts.items():
        # Call the function to add new points for each object
        _, out_obj_ids, out_mask_logits = predictor.add_new_points(
            inference_state=inference_state,
            frame_idx=frame_idx,
            obj_id=obj_id,
            points=points,
            labels=labels
        )
    return _, out_obj_ids, out_mask_logits
In [125]:
# Process the user-selected labeled points
prompts = make_prompts(data)
_, out_obj_ids, out_mask_logits = add_prompts(prompts, frame_idx, inference_state)
In [127]:
# Show the results on the current (interacted) frame
plt.figure(figsize=(12, 8))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], plt.gca())
    mask = (out_mask_logits[i] > 0.0).cpu().numpy()
    show_mask(mask, plt.gca(), obj_id=out_obj_id)

We have now created an accurate mask with SAM 2 and are ready to track it throughout the video!¶

Run Inference¶

In [128]:
# Function to propagate the masks throughout the video
def propagate_masks(inference_state, video_segments=None, start_frame_idx=None):
    # Avoid a mutable default argument; pass an existing dict when refining
    if video_segments is None:
        video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
        inference_state, start_frame_idx=start_frame_idx
    ):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return video_segments
In [129]:
video_segments = propagate_masks(inference_state)  # propagation begins at the prompted frame by default
propagate in video: 100%|██████████| 239/239 [00:09<00:00, 26.35it/s]
In [130]:
def view_labeled_frames(frame_stride, frame_names, video_segments, video_dir, cols=4):
    plt.close("all")
    frames_to_display = list(range(0, len(frame_names), frame_stride))
    rows = math.ceil(len(frames_to_display) / cols)
    
    fig, axes = plt.subplots(rows, cols, figsize=(15, 3 * rows))
    axes = np.atleast_1d(axes).ravel()  # works for 1x1, 1xN, and MxN grids alike

    for ax, out_frame_idx in zip(axes, frames_to_display):
        ax.set_title(f"frame {out_frame_idx}")
        ax.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
        if out_frame_idx in video_segments:
            for out_obj_id, out_mask in video_segments[out_frame_idx].items():
                if out_mask is not None and np.any(out_mask):
                    show_mask(out_mask, ax, obj_id=out_obj_id)
        ax.axis('off')

    # Hide any unused subplots
    for ax in axes[len(frames_to_display):]:
        ax.axis('off')

    plt.subplots_adjust(wspace=0.1, hspace=0)
    plt.tight_layout(pad=0.5)
    display(fig)
    plt.close(fig)

View the Output and Find Mistakes¶

In [131]:
# Render the segmentation results every few frames
frame_stride = 15
view_labeled_frames(frame_stride, frame_names, video_segments, video_dir)

As seen in the frames above, SAM 2 does a great job of maintaining Marshawn Lynch's masklet, even as he runs through defenders. However, in frame 210, we notice that the mask is reduced to only the crown of his helmet, even though his legs are seen in the air. Because of this, by the time he emerges victorious in frame 225, the masklet has been completely lost.¶
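Rather than eyeballing the strided grid for failures, a small helper can flag candidate frames automatically by looking for frames where an object's mask comes back empty. This is a rough heuristic of my own, not part of SAM 2's API; tracking a sudden drop in mask area would catch partial losses (like the frame-210 helmet crown) as well:

```python
import numpy as np

def find_lost_frames(video_segments):
    """Return {obj_id: [frame indices]} where an object's propagated mask is empty."""
    lost = {}
    for frame_idx in sorted(video_segments):
        for obj_id, mask in video_segments[frame_idx].items():
            if mask is None or not np.any(mask):
                lost.setdefault(obj_id, []).append(frame_idx)
    return lost

# Toy example: object 1's mask disappears at frame 2
toy_segments = {
    0: {1: np.ones((1, 4, 4), dtype=bool)},
    1: {1: np.ones((1, 4, 4), dtype=bool)},
    2: {1: np.zeros((1, 4, 4), dtype=bool)},
}
assert find_lost_frames(toy_segments) == {1: [2]}
```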

The promptable nature of SAM 2 not only allows us to instantiate masks but also to refine its predictions at any point in the video. We will therefore provide the correct mask for frame 210 by selecting a new set of labeled points, letting SAM 2 recalibrate the masklet it propagates and reach even higher accuracy.¶

Re-label Faulty Masks¶

In [5]:
frame_idx = 210
annotate(frame_idx)
In [134]:
prompts = make_prompts(data)
In [135]:
# View the mask before refinement
fig_before, ax_before = plt.subplots(figsize=(12, 8))
ax_before.set_title(f"Frame {frame_idx} -- Before Refinement")
ax_before.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
for out_obj_id, out_mask in video_segments[frame_idx].items():
    if out_mask is not None and np.any(out_mask):
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
display(fig_before)
plt.close(fig_before)

# Give the new prompts to SAM 2
_, out_obj_ids, out_mask_logits = add_prompts(prompts, frame_idx, inference_state, is_refinement=True)

# View the mask after refinement
fig_after, ax_after = plt.subplots(figsize=(12, 8))
ax_after.set_title(f"Frame {frame_idx} -- After Refinement")
ax_after.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], plt.gca())
    mask = (out_mask_logits[i] > 0.0).cpu().numpy()
    show_mask(mask, plt.gca(), obj_id=out_obj_id)
display(fig_after)
plt.close(fig_after)

Now SAM 2 has captured a much better masklet for frame 210, which it will propagate henceforth.¶

Output Final Results¶

In [136]:
video_segments = propagate_masks(inference_state, video_segments, start_frame_idx=frame_idx)
propagate in video: 100%|██████████| 29/29 [00:01<00:00, 27.59it/s]
In [137]:
view_labeled_frames(15, frame_names, video_segments, video_dir)
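Step 7 of the workflow is writing the labels out. One plausible format (an assumption on my part, not something SAM 2 prescribes) is a flat list of per-frame bounding-box records derived from the masks, which can then be dumped to JSON or CSV:

```python
import numpy as np

def export_boxes(video_segments):
    """Convert {frame: {obj_id: mask}} into per-frame bounding-box records."""
    records = []
    for frame_idx in sorted(video_segments):
        for obj_id, mask in video_segments[frame_idx].items():
            if mask is None or not np.any(mask):
                continue  # skip frames where the object was lost
            rows, cols = np.where(mask.squeeze())  # same logic as mask_to_bb
            records.append({
                "frame": frame_idx,
                "object_id": obj_id,
                "bbox_xyxy": [int(cols.min()), int(rows.min()),
                              int(cols.max()), int(rows.max())],
            })
    return records

# Toy example: one object, one frame, a 2x3 block of True pixels padded by 1
toy_segments = {0: {1: np.pad(np.ones((2, 3), dtype=bool), 1)}}
records = export_boxes(toy_segments)
assert records == [{"frame": 0, "object_id": 1, "bbox_xyxy": [1, 1, 3, 2]}]
```

From here, `json.dump(records, f)` or a `csv.DictWriter` persists the labels in whatever format the downstream tracking task expects.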

Voila! Combining its impressive out-of-the-box performance with a simple course correction, SAM 2 successfully tracks Marshawn Lynch throughout his turbulent trip to the end zone.¶

The constant motion, sudden changes in direction, and presence of other players all add to the complexity of the scene. Completing such a difficult, previously intractable task demonstrates the robustness of SAM 2 and the bright future of spatiotemporal vision models.¶